Libraries required for this analysis

# Center every figure produced by the knitr chunks below.
knitr::opts_chunk$set(fig.align="center") 
# Packages for model fitting (brm() from brms), posterior-draw manipulation
# and interval plots (tidybayes), data wrangling and pipes (tidyverse,
# magrittr) and plotting (ggplot2).
# NOTE(review): load order determines which package's functions mask
# others', so keep this order. rstanarm, modelr, emmeans, bayesplot and
# gganimate are not called directly in the code visible here; confirm they
# are still needed.
library(rstanarm)
library(tidyverse)
library(tidybayes)
library(modelr) 
library(ggplot2)
library(magrittr)  
library(emmeans)
library(bayesplot)
library(brms)
library(gganimate)

# Default ggplot2 theme for all plots (some plots below override it with
# theme_minimal()).
theme_set(theme_light())

Read in and clean data

# Load the per-trial accuracy data and convert the experimental condition
# columns (oracle, search, dataset) to factors so that brm() treats them as
# categorical predictors.
accuracy_data <- read.csv('processed_accuracy_split.csv') %>%
  mutate(across(c(oracle, search, dataset), as.factor))

# Per-task result containers, filled below under the keys "find_extremum"
# and "retrieve_value":
#   models             - fitted brms models
#   draw_data          - posterior draws of predicted accuracy
#   search_differences - posterior accuracy differences between search algorithms
#   oracle_differences - posterior accuracy differences between oracles
models <- list()
draw_data <- list()
search_differences <- list()
oracle_differences <- list()

# One RNG seed shared by model fitting and posterior-draw sampling.
seed <- 12

In our experiment, we used a visualization recommendation algorithm (composed of one search algorithm and one oracle algorithm) to generate visualizations for the user on one of two datasets. We then measured the user’s accuracy on two tasks: Find Extremum and Retrieve Value.

Given a search algorithm (bfs or dfs), an oracle (compassql or dziban), and a dataset (birdstrikes or movies), we would like to predict a user’s chance of answering the Find Extremum task and the Retrieve Value task correctly. In addition, we would like to know whether the choice of search algorithm and oracle has any meaningful impact on a user’s accuracy for these two tasks.

Find Extremum: Building a Model for Accuracy Analysis

# Keep only the Find Extremum trials.
data_find_extremum <- subset(accuracy_data, task == "1. Find Extremum")

# Bayesian logistic regression (Bernoulli, logit link) of answer accuracy on
# the oracle x search interaction plus a dataset main effect.
# NOTE(review): normal(1, .05) is a very tight prior on the intercept;
# confirm this is intended.
# `file` caches the fit: if the cached model file already exists, brm()
# loads it instead of refitting, so delete the cache after changing this call.
models$find_extremum <- brm(
  formula = accuracy ~ oracle * search + dataset,
  data = data_find_extremum,
  family = bernoulli(link = "logit"),
  prior = c(prior(normal(1, .05), class = Intercept)),
  chains = 2,
  cores = 2,
  iter = 3000,
  warmup = 500,
  seed = seed,
  file = "acc_find_extremum"
)

Find Extremum: Diagnostics + Model Evaluation

In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.

# Posterior summary with convergence diagnostics (Rhat, Bulk_ESS, Tail_ESS).
summary(models[["find_extremum"]])
##  Family: bernoulli 
##   Links: mu = logit 
## Formula: accuracy ~ oracle * search + dataset 
##    Data: data_find_extremum (Number of observations: 59) 
## Samples: 2 chains, each with iter = 3000; warmup = 500; thin = 1;
##          total post-warmup samples = 5000
## 
## Population-Level Effects: 
##                        Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## Intercept                  0.60      0.61    -0.58     1.82 1.00     3020
## oracledziban               0.81      0.87    -0.85     2.51 1.00     2751
## searchdfs                  0.40      0.85    -1.24     2.13 1.00     2785
## datasetmovies              0.20      0.61    -0.98     1.35 1.00     4338
## oracledziban:searchdfs    -1.17      1.21    -3.48     1.18 1.00     2307
##                        Tail_ESS
## Intercept                  3497
## oracledziban               3517
## searchdfs                  2893
## datasetmovies              3393
## oracledziban:searchdfs     2822
## 
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

Trace plots help us check whether there is evidence of non-convergence for the model.

# Posterior densities and trace plots, one per parameter.
plot(models[["find_extremum"]])

In our pairs plots, we want to make sure we don’t have highly correlated parameters (highly correlated parameters mean that our model has difficulty differentiating the effects of those parameters).

# Pairwise scatter plots of posterior draws, used to spot strongly
# correlated parameters.
pairs(models[["find_extremum"]])

A confusion matrix can be used to check our correct classification rate (a useful measure to see how well our model fits our data).

# Classify each trial as correct (1) when the posterior mean of its expected
# probability of a correct answer exceeds 0.5, then cross-tabulate the
# predictions against the observed outcomes.
# fitted() summarises that expected probability directly. predict() instead
# averages random 0/1 posterior-predictive draws, which adds Monte Carlo
# noise near the 0.5 threshold. `type` is not an argument of brms' predict().
pred <- fitted(models$find_extremum)[, "Estimate"]
pred <- if_else(pred > 0.5, 1, 0)
# Fix both levels so the table always has 0 and 1 rows/columns, even when
# the model never predicts one of the classes.
confusion_matrix <- table(
  pred = factor(pred, levels = c(0, 1)),
  observed = factor(pull(data_find_extremum, accuracy), levels = c(0, 1))
)
confusion_matrix
# Correct classification rate: share of trials on the diagonal.
sum(diag(confusion_matrix)) / sum(confusion_matrix)
##     
## pred  0  1
##    1  5 54

Visualization of predicted accuracy (the posterior expected probability of a correct answer) for each condition, via draws from our model posterior. The thicker line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.

# Posterior draws of the expected probability of a correct answer for every
# Find Extremum trial, using population-level effects only (re_formula = NA).
# Each draw is tagged with its task and an "oracle_search" condition label.
# The result is grouped by condition and draw.
draw_data$find_extremum <- data_find_extremum %>%
  add_fitted_draws(models$find_extremum, seed = seed, re_formula = NA) %>%
  ungroup() %>%
  mutate(
    task = "1. Find Extremum",
    condition = paste(oracle, search, sep = "_")
  ) %>%
  group_by(search, oracle, dataset, .draw)

# Posterior distribution of predicted accuracy for each oracle/search
# condition, filled by dataset. stat_halfeye() draws the density plus the
# 50% (thick) and 95% (thin) credible intervals.
# alpha is a fixed setting, so it is passed to stat_halfeye() rather than
# mapped in aes(). A constant inside aes() goes through scale_alpha (so the
# rendered transparency is not 0.5) and adds a meaningless "alpha" legend.
find_extremum_plot <- draw_data$find_extremum %>%
  ggplot(aes(x = .value, y = condition, fill = dataset)) +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  labs(x = "Predicted Accuracy (p_correct)", y = "Oracle/Search Combination")

find_extremum_plot

Since the credible intervals on our plot overlap and are hard to read precisely, we use mean_qi to get the numeric boundaries of the different intervals.

# Posterior mean plus 95% and 50% quantile intervals of predicted accuracy
# for every search x oracle x dataset condition.
fit_info <- mean_qi(
  group_by(draw_data$find_extremum, search, oracle, dataset),
  .value,
  .width = c(.95, .5)
)
fit_info
## # A tibble: 16 x 9
## # Groups:   search, oracle [4]
##    search oracle    dataset     .value .lower .upper .width .point .interval
##    <fct>  <fct>     <fct>        <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
##  1 bfs    compassql birdstrikes  0.636  0.359  0.861   0.95 mean   qi       
##  2 bfs    compassql movies       0.678  0.423  0.882   0.95 mean   qi       
##  3 bfs    dziban    birdstrikes  0.787  0.552  0.938   0.95 mean   qi       
##  4 bfs    dziban    movies       0.817  0.605  0.949   0.95 mean   qi       
##  5 dfs    compassql birdstrikes  0.717  0.454  0.902   0.95 mean   qi       
##  6 dfs    compassql movies       0.753  0.519  0.922   0.95 mean   qi       
##  7 dfs    dziban    birdstrikes  0.645  0.380  0.865   0.95 mean   qi       
##  8 dfs    dziban    movies       0.688  0.441  0.879   0.95 mean   qi       
##  9 bfs    compassql birdstrikes  0.636  0.547  0.732   0.5  mean   qi       
## 10 bfs    compassql movies       0.678  0.600  0.768   0.5  mean   qi       
## 11 bfs    dziban    birdstrikes  0.787  0.730  0.862   0.5  mean   qi       
## 12 bfs    dziban    movies       0.817  0.766  0.882   0.5  mean   qi       
## 13 dfs    compassql birdstrikes  0.717  0.644  0.802   0.5  mean   qi       
## 14 dfs    compassql movies       0.753  0.685  0.832   0.5  mean   qi       
## 15 dfs    dziban    birdstrikes  0.645  0.557  0.739   0.5  mean   qi       
## 16 dfs    dziban    movies       0.688  0.613  0.773   0.5  mean   qi
## Saving 7 x 5 in image

Find Extremum: Differences Between Conditions

Next, we want to see whether there is any meaningful difference between the two search algorithms (bfs and dfs) and between the two oracles (dziban and compassql).

Differences in search algorithms:

# Posterior predictive draws (simulated 0/1 answers) for every Find Extremum
# trial, population-level effects only.
find_extremum_predictive_data  <- data_find_extremum %>%
    add_predicted_draws(models$find_extremum, seed = seed, re_formula = NA) %>%
    group_by(search, oracle, dataset, .draw)

# Per draw, the mean simulated accuracy of each search algorithm within each
# dataset, followed by the difference between the two search algorithms.
# Uses mean(): weighted.mean() without weights computed the same plain mean.
# NOTE(review): these are posterior *predictive* draws, so the intervals
# include Bernoulli sampling noise. add_fitted_draws() would give the
# expected difference that the plot label refers to; confirm which is intended.
search_differences$find_extremum <- find_extremum_predictive_data %>%
    group_by(search, dataset, .draw) %>%
    # "drop_last" is the default; spelling it out silences the regroup message.
    summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
    compare_levels(accuracy, by = search) %>%
    rename(difference_in_accuracy = accuracy)
search_differences$find_extremum$metric <- "1. Find Extremum"

# Posterior distribution of the accuracy difference between the two search
# algorithms (label from compare_levels(), e.g. "dfs - bfs"), one panel per
# dataset. The dashed line marks zero difference.
# alpha is a fixed setting passed to stat_halfeye(). Mapped inside aes(), it
# went through scale_alpha and added a meaningless "alpha" legend.
# $search[[1]] extracts the comparison label as a plain value instead of a
# 1x1 tibble that paste0() has to coerce.
search_differences$find_extremum %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0(
    "Expected Difference in Accuracy (",
    search_differences$find_extremum$search[[1]],
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)

We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.

# Posterior mean and 95% / 50% intervals of the search-algorithm difference.
mean_qi(search_differences$find_extremum, difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups:   search [1]
##   search    dataset   difference_in_accur… .lower .upper .width .point .interval
##   <chr>     <fct>                    <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 dfs - bfs birdstri…              -0.0300 -0.429 0.357    0.95 mean   qi       
## 2 dfs - bfs movies                 -0.0328 -0.375 0.342    0.95 mean   qi       
## 3 dfs - bfs birdstri…              -0.0300 -0.143 0.143    0.5  mean   qi       
## 4 dfs - bfs movies                 -0.0328 -0.171 0.0833   0.5  mean   qi

Differences in oracle:

# Per posterior predictive draw, the mean simulated accuracy of each oracle
# within each dataset, followed by the difference between the two oracles.
# Uses mean(): weighted.mean() without weights computed the same plain mean.
# NOTE(review): built on posterior *predictive* draws, so the intervals
# include Bernoulli sampling noise; confirm this matches the "Expected
# Difference" wording used in the plots.
oracle_differences$find_extremum <- find_extremum_predictive_data %>%
    group_by(oracle, dataset, .draw) %>%
    # "drop_last" is the default; spelling it out silences the regroup message.
    summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
    compare_levels(accuracy, by = oracle) %>%
    rename(difference_in_accuracy = accuracy)
oracle_differences$find_extremum$metric <- "1. Find Extremum"

# Posterior distribution of the accuracy difference between the two oracles
# (label from compare_levels(), e.g. "dziban - compassql"), one panel per
# dataset. The dashed line marks zero difference.
# alpha is a fixed setting passed to stat_halfeye(). Mapped inside aes(), it
# went through scale_alpha and added a meaningless "alpha" legend.
# $oracle[[1]] extracts the comparison label as a plain value instead of a
# 1x1 tibble.
oracle_differences$find_extremum %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0(
    "Expected Difference in Accuracy (",
    oracle_differences$find_extremum$oracle[[1]],
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)

We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.

# Posterior mean and 95% / 50% intervals of the oracle difference.
mean_qi(oracle_differences$find_extremum, difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups:   oracle [1]
##   oracle      dataset  difference_in_acc…  .lower .upper .width .point .interval
##   <chr>       <fct>                 <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 dziban - c… birdstr…             0.0412 -0.357   0.429   0.95 mean   qi       
## 2 dziban - c… movies               0.0326 -0.312   0.404   0.95 mean   qi       
## 3 dziban - c… birdstr…             0.0412 -0.0714  0.143   0.5  mean   qi       
## 4 dziban - c… movies               0.0326 -0.108   0.15    0.5  mean   qi

Retrieve Value: Building a Model for Accuracy Analysis

# Keep only the Retrieve Value trials.
data_retrieve_value <- subset(accuracy_data, task == "2. Retrieve Value")

# Bayesian logistic regression (Bernoulli, logit link) of answer accuracy on
# the oracle x search interaction plus a dataset main effect.
# NOTE(review): normal(1, .05) is a very tight prior on the intercept;
# confirm this is intended.
# `file` caches the fit: if the cached model file already exists, brm()
# loads it instead of refitting, so delete the cache after changing this call.
models$retrieve_value <- brm(
  formula = accuracy ~ oracle * search + dataset,
  data = data_retrieve_value,
  family = bernoulli(link = "logit"),
  prior = c(prior(normal(1, .05), class = Intercept)),
  chains = 2,
  cores = 2,
  iter = 3000,
  warmup = 500,
  seed = seed,
  file = "acc_retrieve_value"
)

Retrieve Value: Diagnostics + Model Evaluation

In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.

# Posterior summary with convergence diagnostics (Rhat, Bulk_ESS, Tail_ESS).
summary(models[["retrieve_value"]])
##  Family: bernoulli 
##   Links: mu = logit 
## Formula: accuracy ~ oracle * search + dataset 
##    Data: data_retrieve_value (Number of observations: 59) 
## Samples: 2 chains, each with iter = 3000; warmup = 500; thin = 1;
##          total post-warmup samples = 5000
## 
## Population-Level Effects: 
##                        Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## Intercept                  0.98      0.60    -0.15     2.19 1.00     3087
## oracledziban               0.40      0.84    -1.23     2.04 1.00     2815
## searchdfs                  0.83      0.89    -0.89     2.58 1.00     2933
## datasetmovies             -0.51      0.62    -1.76     0.71 1.00     3744
## oracledziban:searchdfs    -1.18      1.21    -3.53     1.21 1.00     2424
##                        Tail_ESS
## Intercept                  3083
## oracledziban               2994
## searchdfs                  3226
## datasetmovies              3214
## oracledziban:searchdfs     3159
## 
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

Trace plots help us check whether there is evidence of non-convergence for the model.

# Posterior densities and trace plots, one per parameter.
plot(models[["retrieve_value"]])

In our pairs plots, we want to make sure we don’t have highly correlated parameters (highly correlated parameters mean that our model has difficulty differentiating the effects of those parameters).

# Pairwise scatter plots of posterior draws, used to spot strongly
# correlated parameters.
pairs(models[["retrieve_value"]])

A confusion matrix can be used to check our correct classification rate (a useful measure to see how well our model fits our data).

# Classify each trial as correct (1) when the posterior mean of its expected
# probability of a correct answer exceeds 0.5, then cross-tabulate the
# predictions against the observed outcomes.
# fitted() summarises that expected probability directly. predict() instead
# averages random 0/1 posterior-predictive draws, which adds Monte Carlo
# noise near the 0.5 threshold. `type` is not an argument of brms' predict().
pred <- fitted(models$retrieve_value)[, "Estimate"]
pred <- if_else(pred > 0.5, 1, 0)
# Fix both levels so the table always has 0 and 1 rows/columns, even when
# the model never predicts one of the classes.
confusion_matrix <- table(
  pred = factor(pred, levels = c(0, 1)),
  observed = factor(pull(data_retrieve_value, accuracy), levels = c(0, 1))
)
confusion_matrix
# Correct classification rate: share of trials on the diagonal.
sum(diag(confusion_matrix)) / sum(confusion_matrix)
##     
## pred  0  1
##    1  5 54

Visualization of predicted accuracy (the posterior expected probability of a correct answer) for each condition, via draws from our model posterior. The thicker line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.

# Posterior draws of the expected probability of a correct answer for every
# Retrieve Value trial, using population-level effects only (re_formula = NA).
# Each draw is tagged with its task and an "oracle_search" condition label.
# The result is grouped by condition and draw.
draw_data$retrieve_value <- data_retrieve_value %>%
  add_fitted_draws(models$retrieve_value, seed = seed, re_formula = NA) %>%
  ungroup() %>%
  mutate(
    task = "2. Retrieve Value",
    condition = paste(oracle, search, sep = "_")
  ) %>%
  group_by(search, oracle, dataset, .draw)

# Posterior distribution of predicted accuracy for each oracle/search
# condition, filled by dataset. stat_halfeye() draws the density plus the
# 50% (thick) and 95% (thin) credible intervals.
# alpha is a fixed setting, so it is passed to stat_halfeye() rather than
# mapped in aes(). A constant inside aes() goes through scale_alpha (so the
# rendered transparency is not 0.5) and adds a meaningless "alpha" legend.
retrieve_value_plot <- draw_data$retrieve_value %>%
  ggplot(aes(x = .value, y = condition, fill = dataset)) +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  labs(x = "Predicted Accuracy (p_correct)", y = "Oracle/Search Combination")

retrieve_value_plot

Since the credible intervals on our plot overlap and are hard to read precisely, we use mean_qi to get the numeric boundaries of the different intervals.

# Posterior mean plus 95% and 50% quantile intervals of predicted accuracy
# for every search x oracle x dataset condition.
fit_info <- mean_qi(
  group_by(draw_data$retrieve_value, search, oracle, dataset),
  .value,
  .width = c(.95, .5)
)
fit_info
## # A tibble: 16 x 9
## # Groups:   search, oracle [4]
##    search oracle    dataset     .value .lower .upper .width .point .interval
##    <fct>  <fct>     <fct>        <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
##  1 bfs    compassql birdstrikes  0.712  0.462  0.899   0.95 mean   qi       
##  2 bfs    compassql movies       0.605  0.325  0.844   0.95 mean   qi       
##  3 bfs    dziban    birdstrikes  0.780  0.547  0.935   0.95 mean   qi       
##  4 bfs    dziban    movies       0.689  0.426  0.887   0.95 mean   qi       
##  5 dfs    compassql birdstrikes  0.841  0.632  0.958   0.95 mean   qi       
##  6 dfs    compassql movies       0.766  0.514  0.929   0.95 mean   qi       
##  7 dfs    dziban    birdstrikes  0.720  0.461  0.905   0.95 mean   qi       
##  8 dfs    dziban    movies       0.615  0.340  0.843   0.95 mean   qi       
##  9 bfs    compassql birdstrikes  0.712  0.636  0.797   0.5  mean   qi       
## 10 bfs    compassql movies       0.605  0.516  0.704   0.5  mean   qi       
## 11 bfs    dziban    birdstrikes  0.780  0.721  0.857   0.5  mean   qi       
## 12 bfs    dziban    movies       0.689  0.612  0.778   0.5  mean   qi       
## 13 dfs    compassql birdstrikes  0.841  0.796  0.903   0.5  mean   qi       
## 14 dfs    compassql movies       0.766  0.702  0.847   0.5  mean   qi       
## 15 dfs    dziban    birdstrikes  0.720  0.648  0.804   0.5  mean   qi       
## 16 dfs    dziban    movies       0.615  0.527  0.711   0.5  mean   qi
## Saving 7 x 5 in image

Retrieve Value: Differences Between Conditions

Next, we want to see whether there is any meaningful difference between the two search algorithms (bfs and dfs) and between the two oracles (dziban and compassql).

Differences in search algorithms:

# Posterior predictive draws (simulated 0/1 answers) for every Retrieve Value
# trial, population-level effects only.
retrieve_value_predictive_data  <- data_retrieve_value %>%
    add_predicted_draws(models$retrieve_value, seed = seed, re_formula = NA) %>%
    group_by(search, oracle, dataset, .draw)

# Per draw, the mean simulated accuracy of each search algorithm within each
# dataset, followed by the difference between the two search algorithms.
# Uses mean(): weighted.mean() without weights computed the same plain mean.
# NOTE(review): these are posterior *predictive* draws, so the intervals
# include Bernoulli sampling noise. add_fitted_draws() would give the
# expected difference that the plot label refers to; confirm which is intended.
search_differences$retrieve_value <- retrieve_value_predictive_data %>%
    group_by(search, dataset, .draw) %>%
    # "drop_last" is the default; spelling it out silences the regroup message.
    summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
    compare_levels(accuracy, by = search) %>%
    rename(difference_in_accuracy = accuracy)
search_differences$retrieve_value$metric <- "2. Retrieve Value"

# Posterior distribution of the accuracy difference between the two search
# algorithms (label from compare_levels(), e.g. "dfs - bfs"), one panel per
# dataset. The dashed line marks zero difference.
# alpha is a fixed setting passed to stat_halfeye(). Mapped inside aes(), it
# went through scale_alpha and added a meaningless "alpha" legend.
# $search[[1]] extracts the comparison label as a plain value instead of a
# 1x1 tibble.
search_differences$retrieve_value %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0(
    "Expected Difference in Accuracy (",
    search_differences$retrieve_value$search[[1]],
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)

We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.

# Posterior mean and 95% / 50% intervals of the search-algorithm difference.
mean_qi(search_differences$retrieve_value, difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups:   search [1]
##   search   dataset    difference_in_accu…  .lower .upper .width .point .interval
##   <chr>    <fct>                    <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 dfs - b… birdstrik…              0.0345 -0.287   0.429   0.95 mean   qi       
## 2 dfs - b… movies                  0.0387 -0.367   0.467   0.95 mean   qi       
## 3 dfs - b… birdstrik…              0.0345 -0.0714  0.143   0.5  mean   qi       
## 4 dfs - b… movies                  0.0387 -0.104   0.158   0.5  mean   qi

Differences in oracle:

# Per posterior predictive draw, the mean simulated accuracy of each oracle
# within each dataset, followed by the difference between the two oracles.
# Uses mean(): weighted.mean() without weights computed the same plain mean.
# NOTE(review): built on posterior *predictive* draws, so the intervals
# include Bernoulli sampling noise; confirm this matches the "Expected
# Difference" wording used in the plots.
oracle_differences$retrieve_value <- retrieve_value_predictive_data %>%
    group_by(oracle, dataset, .draw) %>%
    # "drop_last" is the default; spelling it out silences the regroup message.
    summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
    compare_levels(accuracy, by = oracle) %>%
    rename(difference_in_accuracy = accuracy)
oracle_differences$retrieve_value$metric <- "2. Retrieve Value"

# Posterior distribution of the accuracy difference between the two oracles
# (label from compare_levels(), e.g. "dziban - compassql"), one panel per
# dataset. The dashed line marks zero difference.
# alpha is a fixed setting passed to stat_halfeye(). Mapped inside aes(), it
# went through scale_alpha and added a meaningless "alpha" legend.
# $oracle[[1]] extracts the comparison label as a plain value instead of a
# 1x1 tibble.
oracle_differences$retrieve_value %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0(
    "Expected Difference in Accuracy (",
    oracle_differences$retrieve_value$oracle[[1]],
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)

We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.

# Posterior mean and 95% / 50% intervals of the oracle difference.
mean_qi(oracle_differences$retrieve_value, difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups:   oracle [1]
##   oracle      dataset   difference_in_acc… .lower .upper .width .point .interval
##   <chr>       <fct>                  <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 dziban - c… birdstri…            -0.0240 -0.357 0.357    0.95 mean   qi       
## 2 dziban - c… movies               -0.0444 -0.429 0.346    0.95 mean   qi       
## 3 dziban - c… birdstri…            -0.0240 -0.143 0.0714   0.5  mean   qi       
## 4 dziban - c… movies               -0.0444 -0.175 0.0875   0.5  mean   qi

Summary Plots

Putting all of the plots for search algorithm differences on the same plot:

# Stack the search-algorithm differences from both tasks and plot them
# together: one row per task (metric), one panel per dataset.
# bind_rows() replaces rbind(). The inputs are grouped tibbles, and rbind()
# falls back to rbind.data.frame(), which does not rebuild dplyr's group
# metadata for the appended rows.
combined_search_differences <- bind_rows(
  search_differences$find_extremum,
  search_differences$retrieve_value
)
# alpha is a fixed setting passed to stat_halfeye(). Mapped inside aes(), it
# went through scale_alpha and added a meaningless "alpha" legend.
search_differences_plot <- combined_search_differences %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0(
    "Expected Difference in Accuracy (",
    combined_search_differences$search[[1]],
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)

search_differences_plot

# Posterior mean and 95% / 50% intervals of the search-algorithm difference
# for every comparison x dataset x task combination.
search_intervals <- mean_qi(
  group_by(combined_search_differences, search, dataset, metric),
  difference_in_accuracy,
  .width = c(.95, .5)
)
search_intervals
## # A tibble: 8 x 9
## # Groups:   search, dataset [2]
##   search  dataset metric difference_in_a…  .lower .upper .width .point .interval
##   <chr>   <fct>   <chr>             <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 dfs - … birdst… 1. Fi…          -0.0300 -0.429  0.357    0.95 mean   qi       
## 2 dfs - … birdst… 2. Re…           0.0345 -0.287  0.429    0.95 mean   qi       
## 3 dfs - … movies  1. Fi…          -0.0328 -0.375  0.342    0.95 mean   qi       
## 4 dfs - … movies  2. Re…           0.0387 -0.367  0.467    0.95 mean   qi       
## 5 dfs - … birdst… 1. Fi…          -0.0300 -0.143  0.143    0.5  mean   qi       
## 6 dfs - … birdst… 2. Re…           0.0345 -0.0714 0.143    0.5  mean   qi       
## 7 dfs - … movies  1. Fi…          -0.0328 -0.171  0.0833   0.5  mean   qi       
## 8 dfs - … movies  2. Re…           0.0387 -0.104  0.158    0.5  mean   qi

Putting all of the plots for oracle differences on the same plot:

# Stack the oracle differences from both tasks and plot them together:
# one row per task (metric), one panel per dataset.
# bind_rows() replaces rbind(). The inputs are grouped tibbles, and rbind()
# falls back to rbind.data.frame(), which does not rebuild dplyr's group
# metadata for the appended rows.
combined_oracle_differences <- bind_rows(
  oracle_differences$find_extremum,
  oracle_differences$retrieve_value
)
# alpha is a fixed setting passed to stat_halfeye(). Mapped inside aes(), it
# went through scale_alpha and added a meaningless "alpha" legend.
oracle_differences_plot <- combined_oracle_differences %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0(
    "Expected Difference in Accuracy (",
    combined_oracle_differences$oracle[[1]],
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)

oracle_differences_plot

# Posterior mean and 95% / 50% intervals of the oracle difference for every
# comparison x dataset x task combination.
oracle_intervals <- mean_qi(
  group_by(combined_oracle_differences, oracle, dataset, metric),
  difference_in_accuracy,
  .width = c(.95, .5)
)
oracle_intervals
## # A tibble: 8 x 9
## # Groups:   oracle, dataset [2]
##   oracle  dataset metric difference_in_a…  .lower .upper .width .point .interval
##   <chr>   <fct>   <chr>             <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 dziban… birdst… 1. Fi…           0.0412 -0.357  0.429    0.95 mean   qi       
## 2 dziban… birdst… 2. Re…          -0.0240 -0.357  0.357    0.95 mean   qi       
## 3 dziban… movies  1. Fi…           0.0326 -0.312  0.404    0.95 mean   qi       
## 4 dziban… movies  2. Re…          -0.0444 -0.429  0.346    0.95 mean   qi       
## 5 dziban… birdst… 1. Fi…           0.0412 -0.0714 0.143    0.5  mean   qi       
## 6 dziban… birdst… 2. Re…          -0.0240 -0.143  0.0714   0.5  mean   qi       
## 7 dziban… movies  1. Fi…           0.0326 -0.108  0.15     0.5  mean   qi       
## 8 dziban… movies  2. Re…          -0.0444 -0.175  0.0875   0.5  mean   qi